{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "91e1016d",
   "metadata": {},
   "source": [
    "# Partitioned Training\n",
    "The following example shows how to train a knowledge graph embedding model with partitioning, evaluate it, and then save and restore it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "# Environment variables are set before TensorFlow is imported so they are in place when it loads.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "import numpy as np\n",
    "import ampligraph\n",
    "from ampligraph.evaluation.metrics import hits_at_n_score, mr_score, mrr_score\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "from ampligraph.utils import restore_model, save_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e03ab8",
   "metadata": {},
   "source": [
    "## Train and evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_split: memory before: 848.0Bytes, after: 4.3447MB, consumed: 4.3439MB; exec time: 29.242s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/alberto.bernardi/miniforge3/lib/python3.10/site-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
      "  warnings.warn(\n",
      "2023-02-08 16:47:49.873938: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "12/12 [==============================] - 3s 257ms/step - loss: 8690.4297\n",
      "Epoch 2/10\n",
      "12/12 [==============================] - 2s 159ms/step - loss: 8642.6641\n",
      "Epoch 3/10\n",
      "12/12 [==============================] - 2s 173ms/step - loss: 8588.0459\n",
      "Epoch 4/10\n",
      "12/12 [==============================] - 2s 171ms/step - loss: 8524.0557\n",
      "Epoch 5/10\n",
      "12/12 [==============================] - 2s 166ms/step - loss: 8454.0918\n",
      "Epoch 6/10\n",
      "12/12 [==============================] - 2s 167ms/step - loss: 8382.4600\n",
      "Epoch 7/10\n",
      "12/12 [==============================] - 2s 163ms/step - loss: 8313.3584\n",
      "Epoch 8/10\n",
      "12/12 [==============================] - 2s 161ms/step - loss: 8245.1660\n",
      "Epoch 9/10\n",
      "12/12 [==============================] - 2s 160ms/step - loss: 8180.5093\n",
      "Epoch 10/10\n",
      "12/12 [==============================] - 2s 155ms/step - loss: 8118.5229\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x28f31d630>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Folder that contains the 'wn18RR' dataset directory; edit this before running.\n",
    "PATH_TO_DATASET = 'your/path/to/dataset/'\n",
    "# os.path.join keeps the path valid whether or not PATH_TO_DATASET ends with a separator.\n",
    "train_path = os.path.join(PATH_TO_DATASET, 'wn18RR', 'train.txt')\n",
    "\n",
    "# Create the KGE model with the TransE scoring function\n",
    "partitioned_model = ScoringBasedEmbeddingModel(eta=2,\n",
    "                                               k=50,\n",
    "                                               scoring_type='TransE')\n",
    "partitioned_model.compile(optimizer='adam', loss='multiclass_nll')\n",
    "\n",
    "# Here we pass the path of the input file.\n",
    "# You can also load the data with a default dataloader such as load_wn18rr()\n",
    "# and pass numpy array inputs.\n",
    "partitioned_model.fit(train_path,\n",
    "                      batch_size=10000,\n",
    "                      partitioning_k=3,  # set flag to partition the inputs\n",
    "                      epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "24cb1627",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "9/9 [==============================] - 14s 2s/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(20079.140731874144, 0.011132840015629617, 0.0, 0.03625170998632011, 2924)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unfiltered evaluation of the trained model on the test file\n",
    "test_path = os.path.join(PATH_TO_DATASET, 'wn18RR', 'test.txt')\n",
    "ranks_unfiltered = partitioned_model.evaluate(test_path,\n",
    "                                              batch_size=400)\n",
    "\n",
    "# Displayed tuple: MR, MRR, Hits@1, Hits@10 and the number of ranks\n",
    "(mr_score(ranks_unfiltered),\n",
    " mrr_score(ranks_unfiltered),\n",
    " hits_at_n_score(ranks_unfiltered, 1),\n",
    " hits_at_n_score(ranks_unfiltered, 10),\n",
    " len(ranks_unfiltered))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "badb504e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "9/9 [==============================] - 68s 8s/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(20066.594562243503,\n",
       " 0.01583735697421522,\n",
       " 0.005471956224350205,\n",
       " 0.038132694938440494,\n",
       " 2924)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Filtered evaluation: the train, valid and test files are supplied through use_filter\n",
    "wn18rr_dir = os.path.join(PATH_TO_DATASET, 'wn18RR')\n",
    "filter_files = {split: os.path.join(wn18rr_dir, split + '.txt')\n",
    "                for split in ('train', 'valid', 'test')}\n",
    "\n",
    "ranks_filtered = partitioned_model.evaluate(filter_files['test'],\n",
    "                                            batch_size=400,\n",
    "                                            corrupt_side='s,o',\n",
    "                                            use_filter=filter_files)\n",
    "\n",
    "# Displayed tuple: MR, MRR, Hits@1, Hits@10 and the number of ranks\n",
    "(mr_score(ranks_filtered),\n",
    " mrr_score(ranks_filtered),\n",
    " hits_at_n_score(ranks_filtered, 1),\n",
    " hits_at_n_score(ranks_filtered, 10),\n",
    " len(ranks_filtered))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f8d55c73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The path ./partitioned_model already exists. This save operation will overwrite the model                 at the specified path.\n",
      "WARNING - Found untraced functions such as _get_ranks while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    }
   ],
   "source": [
    "# Save the trained model under ./partitioned_model\n",
    "save_model(model=partitioned_model, model_name_path='./partitioned_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d93c7719",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved model does not include a db file. Skipping.\n"
     ]
    }
   ],
   "source": [
    "# Load the model back from ./partitioned_model\n",
    "model = restore_model('./partitioned_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "00fec833",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "9/9 [==============================] - 15s 2s/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(20079.140731874144, 0.011132840015629617, 0.0, 0.03625170998632011, 2924)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unfiltered evaluation of the restored model on the test file\n",
    "test_path = os.path.join(PATH_TO_DATASET, 'wn18RR', 'test.txt')\n",
    "restored_ranks_unfiltered = model.evaluate(test_path,\n",
    "                                           batch_size=400)\n",
    "\n",
    "# Displayed tuple: MR, MRR, Hits@1, Hits@10 and the number of ranks\n",
    "(mr_score(restored_ranks_unfiltered),\n",
    " mrr_score(restored_ranks_unfiltered),\n",
    " hits_at_n_score(restored_ranks_unfiltered, 1),\n",
    " hits_at_n_score(restored_ranks_unfiltered, 10),\n",
    " len(restored_ranks_unfiltered))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "750bdfd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "\n",
      "210 triples containing invalid keys skipped!\n",
      "9/9 [==============================] - 73s 8s/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(20066.594562243503,\n",
       " 0.01583735697421522,\n",
       " 0.005471956224350205,\n",
       " 0.038132694938440494,\n",
       " 2924)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Filtered evaluation of the restored model: the train, valid and test files\n",
    "# are supplied through use_filter\n",
    "wn18rr_dir = os.path.join(PATH_TO_DATASET, 'wn18RR')\n",
    "filter_files = {split: os.path.join(wn18rr_dir, split + '.txt')\n",
    "                for split in ('train', 'valid', 'test')}\n",
    "\n",
    "restored_ranks_filtered = model.evaluate(filter_files['test'],\n",
    "                                         batch_size=400,\n",
    "                                         corrupt_side='s,o',\n",
    "                                         use_filter=filter_files)\n",
    "\n",
    "# Displayed tuple: MR, MRR, Hits@1, Hits@10 and the number of ranks\n",
    "(mr_score(restored_ranks_filtered),\n",
    " mrr_score(restored_ranks_filtered),\n",
    " hits_at_n_score(restored_ranks_filtered, 1),\n",
    " hits_at_n_score(restored_ranks_filtered, 10),\n",
    " len(restored_ranks_filtered))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
